linalg_lu.f90 Source File


Source Code

module linalg_lu
    use iso_fortran_env
    use linalg_errors
    use linalg_sparse
    use linalg_basic
    use ferror
    use lapack
    use sparskit
    implicit none
    private
    public :: lu_factor, form_lu, solve_lu

    ! Generic entry points.  Each generic name resolves to a specific
    ! implementation based on the argument types (real vs. complex, dense
    ! vs. sparse storage).
    interface lu_factor
        module procedure :: lu_factor_dbl, lu_factor_cmplx, csr_lu_factor
    end interface lu_factor

    interface form_lu
        module procedure :: form_lu_all, form_lu_all_cmplx
        module procedure :: form_lu_only, form_lu_only_cmplx
    end interface form_lu

    interface solve_lu
        module procedure :: solve_lu_mtx, solve_lu_mtx_cmplx
        module procedure :: solve_lu_vec, solve_lu_vec_cmplx
        module procedure :: csr_lu_solve
    end interface solve_lu
contains
! ------------------------------------------------------------------------------
subroutine lu_factor_dbl(a, ipvt, err)
    !! Computes the LU factorization of an M-by-N matrix using the LAPACK
    !! routine DGETRF (partial pivoting with row interchanges).
    !!
    !! If DGETRF reports an exactly zero pivot, a singular-matrix warning is
    !! reported through the error object; the factored matrix is still
    !! returned in A.
    real(real64), intent(inout), dimension(:,:) :: a
        !! On input, the M-by-N matrix on which to operate.  On output, the 
        !! LU factored matrix in the form [L\\U] where the unit diagonal
        !! elements of L are not stored.
    integer(int32), intent(out), dimension(:) :: ipvt
        !! An MIN(M, N)-element array used to track row-pivot operations.  The 
        !! array stores pivot information such that row I is interchanged with 
        !! row IPVT(I).
    class(errors), intent(inout), optional, target :: err
        !! An error object to report any errors that occur.

    ! Local Variables
    integer(int32) :: m, n, mn, lda, flag
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Initialization
    m = size(a, 1)
    n = size(a, 2)
    mn = min(m, n)
    ! LAPACK requires LDA >= MAX(1, M).  Passing LDA = M for an empty matrix
    ! (M = 0) would make DGETRF flag an illegal argument through XERBLA,
    ! which halts the program in the reference LAPACK implementation.
    lda = max(1, m)
    if (present(err)) then
        errmgr => err
    else
        errmgr => deferr
    end if

    ! Input Check
    flag = 0
    if (size(ipvt) /= mn) then
        ! ERROR: IPVT not sized correctly
        call report_array_size_error("lu_factor_dbl", errmgr, "ipvt", mn, &
            size(ipvt))
        return
    end if

    ! Compute the LU factorization by calling the LAPACK routine DGETRF
    call DGETRF(m, n, a, lda, ipvt, flag)

    ! If flag > 0, the matrix is singular.  Flag cannot be < 0 here because
    ! every argument has been validated (including LDA) prior to the call.
    if (flag > 0) then
        ! WARNING: Singular matrix
        call report_singular_matrix_warning("lu_factor_dbl", errmgr, flag)
    end if
end subroutine lu_factor_dbl

! ------------------------------------------------------------------------------
subroutine lu_factor_cmplx(a, ipvt, err)
    !! Computes the LU factorization of a complex M-by-N matrix using the
    !! LAPACK routine ZGETRF (partial pivoting with row interchanges).
    !!
    !! If ZGETRF reports an exactly zero pivot, a singular-matrix warning is
    !! reported through the error object; the factored matrix is still
    !! returned in A.
    complex(real64), intent(inout), dimension(:,:) :: a
        !! On input, the M-by-N matrix on which to operate.  On output, the 
        !! LU factored matrix in the form [L\\U] where the unit diagonal
        !! elements of L are not stored.
    integer(int32), intent(out), dimension(:) :: ipvt
        !! An MIN(M, N)-element array used to track row-pivot operations.  The 
        !! array stores pivot information such that row I is interchanged with 
        !! row IPVT(I).
    class(errors), intent(inout), optional, target :: err
        !! An error object to report any errors that occur.

    ! Local Variables
    integer(int32) :: m, n, mn, lda, flag
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Initialization
    m = size(a, 1)
    n = size(a, 2)
    mn = min(m, n)
    ! LAPACK requires LDA >= MAX(1, M).  Passing LDA = M for an empty matrix
    ! (M = 0) would make ZGETRF flag an illegal argument through XERBLA,
    ! which halts the program in the reference LAPACK implementation.
    lda = max(1, m)
    if (present(err)) then
        errmgr => err
    else
        errmgr => deferr
    end if

    ! Input Check
    flag = 0
    if (size(ipvt) /= mn) then
        ! ERROR: IPVT not sized correctly.  Reported through the same helper
        ! as the real-valued overload so both produce consistent diagnostics.
        call report_array_size_error("lu_factor_cmplx", errmgr, "ipvt", mn, &
            size(ipvt))
        return
    end if

    ! Compute the LU factorization by calling the LAPACK routine ZGETRF
    call ZGETRF(m, n, a, lda, ipvt, flag)

    ! If flag > 0, the matrix is singular.  Flag cannot be < 0 here because
    ! every argument has been validated (including LDA) prior to the call.
    if (flag > 0) then
        ! WARNING: Singular matrix
        call report_singular_matrix_warning("lu_factor_cmplx", errmgr, flag)
    end if
end subroutine lu_factor_cmplx

! ------------------------------------------------------------------------------
subroutine csr_lu_factor(a, lu, ju, droptol, err)
    !! Factors a square sparse matrix using the SPARSKIT routine ILUT
    !! (incomplete LU factorization with threshold dropping).
    !!
    !! The workspace handed to ILUT is sized heuristically.  Whenever ILUT
    !! reports that it is too small, the workspace is enlarged and the
    !! factorization retried, so the caller never has to estimate fill-in.
    !! Growth stops at a bound large enough for a completely dense factor.
    class(csr_matrix), intent(in) :: a
        !! The N-by-N matrix to factor.
    type(msr_matrix), intent(out) :: lu
        !! The LU matrix.
    integer(int32), intent(out), dimension(:) :: ju
        !! The N-element row tracking array produced by ILUT.  It must be
        !! passed to solve_lu together with LU.
    real(real64), intent(in), optional :: droptol
        !! The drop tolerance for the ILUT factorization.  Defaults to
        !! SQRT(EPSILON(1.0d0)) when not supplied.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.  Errors are reported when A is not
        !! square, JU is incorrectly sized, memory cannot be allocated, A is
        !! malformed, a zero row or zero pivot is found, or ILUT fails.

    ! Local Variables
    integer(int32) :: m, n, nn, nnz, lfil, iwk, ierr, flag
    integer(int64) :: wk_limit, wk_request
    integer(int32), allocatable, dimension(:) :: jlu, jw
    real(real64), allocatable, dimension(:) :: alu, w
    real(real64) :: dt
    class(errors), pointer :: errmgr
    type(errors), target :: deferr
    
    ! Initialization
    ierr = 0
    flag = 0
    if (present(err)) then
        errmgr => err
    else
        errmgr => deferr
    end if
    if (present(droptol)) then
        dt = droptol
    else
        dt = sqrt(epsilon(dt))
    end if
    m = size(a, 1)
    n = size(a, 2)
    nnz = nonzero_count(a)

    ! Input Check
    if (m /= n) then
        ! ILUT is called with a single order N and reads N + 1 row pointers,
        ! so a non-square matrix would be read out of bounds.
        call report_square_matrix_error("csr_lu_factor", errmgr, "a", m, m, n)
        return
    end if
    if (size(ju) /= m) then
        call report_array_size_error("csr_lu_factor", errmgr, "ju", m, size(ju))
        return
    end if

    ! Parameter Determination: allow as many fill-in entries per row as the
    ! longest row of A holds (at least one; MAXVAL of an empty set is -HUGE)
    lfil = max(1, maxval(a%row_indices(2:m+1) - a%row_indices(1:m)))

    ! Workspace sizing.  The result is copied out in MSR form (see the
    ! M + 1 offset used below), i.e. N + 1 leading slots followed by the
    ! off-diagonal entries, so a completely dense factor needs roughly
    ! N**2 + 1 entries.  WK_LIMIT adds a little headroom on top of that.
    ! NOTE(review): bound derived from SPARSKIT's ILUT storage scheme.
    ! Both quantities are computed in 64-bit arithmetic so large N cannot
    ! overflow, then clamped to what a 32-bit workspace length can express.
    wk_limit = min(int(n, int64)**2 + int(n, int64) + 2_int64, &
        int(huge(iwk), int64))
    wk_request = max(int(lfil, int64) * int(m, int64), int(nnz, int64)) + &
        int(n, int64) + 2_int64
    iwk = int(min(wk_request, wk_limit), int32)

    ! Local Memory Allocation
    allocate(alu(iwk), w(n+1), jlu(iwk), jw(2 * n), stat = flag)
    if (flag /= 0) go to 10

    ! Factorization
    do
        ! Factor the matrix
        call ilut(n, a%values, a%column_indices, a%row_indices, lfil, dt, &
            alu, jlu, ju, iwk, w, jw, ierr)

        ! Check the error flag
        select case (ierr)
        case (0)
            ! Success
            exit
        case (1:)
            ! Zero pivot - retrying with identical inputs cannot succeed
            go to 50
        case (-1)
            ! The input matrix is not formatted correctly
            go to 20
        case (-3:-2)
            ! ALU and JLU are too small - enlarge them and try again.  This is
            ! the main reason for the loop: it offloads worrying about the
            ! workspace size from the user.  Grow by at least M + N entries
            ! (or by half the current size, whichever is larger).  Once the
            ! dense-factor bound has been reached there is nothing left to
            ! grow, so report the failure instead of looping forever.
            if (iwk >= wk_limit) go to 60
            wk_request = int(iwk, int64) + &
                max(int(m, int64) + int(n, int64), int(iwk, int64) / 2_int64)
            iwk = int(min(wk_request, wk_limit), int32)
            deallocate(alu, jlu)
            allocate(alu(iwk), jlu(iwk), stat = flag)
            if (flag /= 0) go to 10
        case (-4)
            ! Illegal value for LFIL - reset and try again
            lfil = n
        case (-5)
            ! Zero row encountered
            go to 30
        case default
            ! We should never get here, but just in case
            go to 40
        end select
    end do

    ! Determine the actual number of non-zero elements
    nnz = jlu(m+1) - 1

    ! Copy the contents to the output arrays
    lu%m = m
    lu%n = n
    lu%nnz = nnz
    nn = m + 1 + nnz - min(m, n)
    allocate(lu%values(nn), source = alu(:nn), stat = flag)
    if (flag /= 0) go to 10
    allocate(lu%indices(nn), source = jlu(:nn), stat = flag)
    if (flag /= 0) go to 10

    ! End
    return

    ! Memory Error
10  continue
    call report_memory_error("csr_lu_factor", errmgr, flag) 
    return

    ! Matrix Format Error
20  continue
    call errmgr%report_error("csr_lu_factor", &
        "The input matrix was incorrectly formatted.  A row with more " // &
        "than N entries was found.", LA_MATRIX_FORMAT_ERROR)
    return

    ! Zero Row Error
30  continue
    call errmgr%report_error("csr_lu_factor", &
        "A row with all zeros was encountered in the matrix.", &
        LA_SINGULAR_MATRIX_ERROR)
    return

    ! Unknown Error
40  continue
    call errmgr%report_error("csr_lu_factor", "ILUT encountered " // &
        "an unknown error.  The error code from the ILUT routine is " // &
        "provided in the output.", ierr)
    return

    ! Zero Pivot Error
50  continue
    call errmgr%report_error("csr_lu_factor", &
        "A zero pivot was encountered.", LA_SINGULAR_MATRIX_ERROR)
    return

    ! Workspace Exhausted Error
60  continue
    call errmgr%report_error("csr_lu_factor", "ILUT requires more " // &
        "workspace than can be provided.  The error code from the ILUT " // &
        "routine is provided in the output.", ierr)
    return
end subroutine csr_lu_factor

! ------------------------------------------------------------------------------
subroutine form_lu_all(lu, ipvt, u, p, err)
    !! Unpacks the compact [L\\U] storage produced by lu_factor into the
    !! separate factors L and U, together with the row permutation matrix P.
    !!
    !! P is obtained by applying the row interchanges recorded in IPVT, in
    !! order, to the identity matrix.
    real(real64), intent(inout), dimension(:,:) :: lu
        !! On input, the N-by-N matrix as output by lu_factor.  On output, the 
        !! N-by-N unit lower triangular matrix L.
    integer(int32), intent(in), dimension(:) :: ipvt
        !! The N-element pivot array as output by lu_factor.
    real(real64), intent(out), dimension(:,:) :: u
        !! An N-by-N matrix where the upper triangular matrix U will be written.
    real(real64), intent(out), dimension(:,:) :: p
        !! An N-by-N matrix where the row permutation matrix will be written.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.

    ! Constants
    real(real64), parameter :: zero = 0.0d0
    real(real64), parameter :: one = 1.0d0

    ! Local Variables
    integer(int32) :: k, row, order
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Use the caller's error object when supplied, otherwise a local default
    errmgr => deferr
    if (present(err)) errmgr => err

    order = size(lu, 1)

    ! Shape validation: LU must be square and every other argument must
    ! match its order.  The first failing check is the one reported.
    if (size(lu, 2) /= order) then
        call report_square_matrix_error("form_lu_all", errmgr, "lu", order, &
            size(lu, 1), size(lu, 2))
        return
    end if
    if (size(ipvt) /= order) then
        call report_array_size_error("form_lu_all", errmgr, "ipvt", order, &
            size(ipvt))
        return
    end if
    if (any(shape(u) /= order)) then
        call report_matrix_size_error("form_lu_all", errmgr, "u", order, &
            order, size(u, 1), size(u, 2))
        return
    end if
    if (any(shape(p) /= order)) then
        call report_matrix_size_error("form_lu_all", errmgr, "p", order, &
            order, size(p, 1), size(p, 2))
        return
    end if

    ! P := I, then replay the recorded row interchanges on it
    p = zero
    do k = 1, order
        p(k, k) = one
    end do
    do k = 1, order
        row = ipvt(k)
        if (row /= k) call swap(p(k,1:order), p(row,1:order))
    end do

    ! Split [L\\U]: the upper triangle (diagonal included) moves into U, and
    ! LU is rewritten in place as the unit lower triangular factor L
    u = zero
    do k = 1, order
        u(1:k,k) = lu(1:k,k)
        lu(1:k-1,k) = zero
        lu(k,k) = one
    end do
end subroutine form_lu_all

! ------------------------------------------------------------------------------
subroutine form_lu_all_cmplx(lu, ipvt, u, p, err)
    !! Unpacks the compact complex [L\\U] storage produced by lu_factor into
    !! the separate factors L and U, together with the (real-valued) row
    !! permutation matrix P.
    !!
    !! P is obtained by applying the row interchanges recorded in IPVT, in
    !! order, to the identity matrix.
    complex(real64), intent(inout), dimension(:,:) :: lu
        !! On input, the N-by-N matrix as output by lu_factor.  On output, the 
        !! N-by-N unit lower triangular matrix L.
    integer(int32), intent(in), dimension(:) :: ipvt
        !! The N-element pivot array as output by lu_factor.
    complex(real64), intent(out), dimension(:,:) :: u
        !! An N-by-N matrix where the upper triangular matrix U will be written.
    real(real64), intent(out), dimension(:,:) :: p
        !! An N-by-N matrix where the row permutation matrix will be written.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.

    ! Constants
    real(real64), parameter :: zero = 0.0d0
    real(real64), parameter :: one = 1.0d0
    complex(real64), parameter :: c_zero = (0.0d0, 0.0d0)
    complex(real64), parameter :: c_one = (1.0d0, 0.0d0)

    ! Local Variables
    integer(int32) :: k, row, order
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Use the caller's error object when supplied, otherwise a local default
    errmgr => deferr
    if (present(err)) errmgr => err

    order = size(lu, 1)

    ! Shape validation: LU must be square and every other argument must
    ! match its order.  The first failing check is the one reported.
    if (size(lu, 2) /= order) then
        call report_square_matrix_error("form_lu_all_cmplx", errmgr, "lu", &
            order, size(lu, 1), size(lu, 2))
        return
    end if
    if (size(ipvt) /= order) then
        call report_array_size_error("form_lu_all_cmplx", errmgr, "ipvt", &
            order, size(ipvt))
        return
    end if
    if (any(shape(u) /= order)) then
        call report_matrix_size_error("form_lu_all_cmplx", errmgr, "u", &
            order, order, size(u, 1), size(u, 2))
        return
    end if
    if (any(shape(p) /= order)) then
        call report_matrix_size_error("form_lu_all_cmplx", errmgr, "p", &
            order, order, size(p, 1), size(p, 2))
        return
    end if

    ! P := I, then replay the recorded row interchanges on it
    p = zero
    do k = 1, order
        p(k, k) = one
    end do
    do k = 1, order
        row = ipvt(k)
        if (row /= k) call swap(p(k,1:order), p(row,1:order))
    end do

    ! Split [L\\U]: the upper triangle (diagonal included) moves into U, and
    ! LU is rewritten in place as the unit lower triangular factor L
    u = c_zero
    do k = 1, order
        u(1:k,k) = lu(1:k,k)
        lu(1:k-1,k) = c_zero
        lu(k,k) = c_one
    end do
end subroutine form_lu_all_cmplx
! ------------------------------------------------------------------------------
subroutine form_lu_only(lu, u, err)
    !! Unpacks the compact [L\\U] storage produced by lu_factor into the
    !! separate factors L and U.  No permutation matrix is formed.
    real(real64), intent(inout), dimension(:,:) :: lu
        !! On input, the N-by-N matrix as output by lu_factor.  On output, the 
        !! N-by-N unit lower triangular matrix L.
    real(real64), intent(out), dimension(:,:) :: u
        !! An N-by-N matrix where the upper triangular matrix U will be written.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.

    ! Constants
    real(real64), parameter :: zero = 0.0d0
    real(real64), parameter :: one = 1.0d0

    ! Local Variables
    integer(int32) :: col, order
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Use the caller's error object when supplied, otherwise a local default
    errmgr => deferr
    if (present(err)) errmgr => err

    order = size(lu, 1)

    ! Shape validation: LU must be square and U must match its order
    if (size(lu, 2) /= order) then
        call report_square_matrix_error("form_lu_only", errmgr, "lu", order, &
            size(lu, 1), size(lu, 2))
        return
    end if
    if (any(shape(u) /= order)) then
        call report_matrix_size_error("form_lu_only", errmgr, "u", order, &
            order, size(u, 1), size(u, 2))
        return
    end if

    ! Split [L\\U]: the upper triangle (diagonal included) moves into U, and
    ! LU is rewritten in place as the unit lower triangular factor L
    u = zero
    do col = 1, order
        u(1:col,col) = lu(1:col,col)
        lu(1:col-1,col) = zero
        lu(col,col) = one
    end do
end subroutine form_lu_only

! ------------------------------------------------------------------------------
subroutine form_lu_only_cmplx(lu, u, err)
    !! Unpacks the compact complex [L\\U] storage produced by lu_factor into
    !! the separate factors L and U.  No permutation matrix is formed.
    complex(real64), intent(inout), dimension(:,:) :: lu
        !! On input, the N-by-N matrix as output by lu_factor.  On output, the 
        !! N-by-N unit lower triangular matrix L.
    complex(real64), intent(out), dimension(:,:) :: u
        !! An N-by-N matrix where the upper triangular matrix U will be written.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.

    ! Constants
    complex(real64), parameter :: zero = (0.0d0, 0.0d0)
    complex(real64), parameter :: one = (1.0d0, 0.0d0)

    ! Local Variables
    integer(int32) :: col, order
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Use the caller's error object when supplied, otherwise a local default
    errmgr => deferr
    if (present(err)) errmgr => err

    order = size(lu, 1)

    ! Shape validation: LU must be square and U must match its order
    if (size(lu, 2) /= order) then
        call report_square_matrix_error("form_lu_only_cmplx", errmgr, "lu", &
            order, size(lu, 1), size(lu, 2))
        return
    end if
    if (any(shape(u) /= order)) then
        call report_matrix_size_error("form_lu_only_cmplx", errmgr, "u", &
            order, order, size(u, 1), size(u, 2))
        return
    end if

    ! Split [L\\U]: the upper triangle (diagonal included) moves into U, and
    ! LU is rewritten in place as the unit lower triangular factor L
    u = zero
    do col = 1, order
        u(1:col,col) = lu(1:col,col)
        lu(1:col-1,col) = zero
        lu(col,col) = one
    end do
end subroutine form_lu_only_cmplx

! ------------------------------------------------------------------------------
subroutine solve_lu_mtx(a, ipvt, b, err)
    !! Solves the system A * X = B for X, given the LU factorization of A
    !! previously computed by lu_factor (LAPACK routine DGETRS).
    real(real64), intent(in), dimension(:,:) :: a
        !! The N-by-N LU factored matrix as output by lu_factor.
    integer(int32), intent(in), dimension(:) :: ipvt
        !! The N-element pivot array as output by lu_factor.
    real(real64), intent(inout), dimension(:,:) :: b
        !! On input, the N-by-NRHS right-hand-side matrix.  On output, the 
        !! N-by-NRHS solution matrix.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.

    ! Local Variables
    integer(int32) :: n, nrhs, ld, flag
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Initialization
    n = size(a, 1)
    nrhs = size(b, 2)
    ! LAPACK requires LDA >= MAX(1, N) and LDB >= MAX(1, N).  Passing N for an
    ! empty system would make DGETRS flag an illegal argument through XERBLA,
    ! which halts the program in the reference LAPACK implementation.
    ld = max(1, n)
    if (present(err)) then
        errmgr => err
    else
        errmgr => deferr
    end if

    ! Input Check
    if (size(a, 2) /= n) then
        call report_square_matrix_error("solve_lu_mtx", errmgr, "a", n, &
            size(a, 1), size(a, 2))
        return
    else if (size(ipvt) /= n) then
        call report_array_size_error("solve_lu_mtx", errmgr, "ipvt", n, &
            size(ipvt))
        return
    else if (size(b, 1) /= n) then
        call report_matrix_size_error("solve_lu_mtx", errmgr, "b", n, &
            size(b, 2), size(b, 1), size(b, 2))
        return
    end if

    ! Call DGETRS.  FLAG is only nonzero for illegal arguments, which the
    ! checks above rule out.
    call DGETRS("N", n, nrhs, a, ld, ipvt, b, ld, flag)
end subroutine solve_lu_mtx

! ------------------------------------------------------------------------------
subroutine solve_lu_mtx_cmplx(a, ipvt, b, err)
    !! Solves the complex system A * X = B for X, given the LU factorization
    !! of A previously computed by lu_factor (LAPACK routine ZGETRS).
    complex(real64), intent(in), dimension(:,:) :: a
        !! The N-by-N LU factored matrix as output by lu_factor.
    integer(int32), intent(in), dimension(:) :: ipvt
        !! The N-element pivot array as output by lu_factor.
    complex(real64), intent(inout), dimension(:,:) :: b
        !! On input, the N-by-NRHS right-hand-side matrix.  On output, the 
        !! N-by-NRHS solution matrix.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.

    ! Local Variables
    integer(int32) :: n, nrhs, ld, flag
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Initialization
    n = size(a, 1)
    nrhs = size(b, 2)
    ! LAPACK requires LDA >= MAX(1, N) and LDB >= MAX(1, N).  Passing N for an
    ! empty system would make ZGETRS flag an illegal argument through XERBLA,
    ! which halts the program in the reference LAPACK implementation.
    ld = max(1, n)
    if (present(err)) then
        errmgr => err
    else
        errmgr => deferr
    end if

    ! Input Check
    if (size(a, 2) /= n) then
        call report_square_matrix_error("solve_lu_mtx_cmplx", errmgr, "a", n, &
            size(a, 1), size(a, 2))
        return
    else if (size(ipvt) /= n) then
        call report_array_size_error("solve_lu_mtx_cmplx", errmgr, "ipvt", n, &
            size(ipvt))
        return
    else if (size(b, 1) /= n) then
        call report_matrix_size_error("solve_lu_mtx_cmplx", errmgr, "b", n, &
            size(b, 2), size(b, 1), size(b, 2))
        return
    end if

    ! Call ZGETRS.  FLAG is only nonzero for illegal arguments, which the
    ! checks above rule out.
    call ZGETRS("N", n, nrhs, a, ld, ipvt, b, ld, flag)
end subroutine solve_lu_mtx_cmplx

! ------------------------------------------------------------------------------
subroutine solve_lu_vec(a, ipvt, b, err)
    !! Solves the system A * x = b for x, given the LU factorization of A
    !! previously computed by lu_factor (LAPACK routine DGETRS).
    real(real64), intent(in), dimension(:,:) :: a
        !! The N-by-N LU factored matrix as output by lu_factor.
    integer(int32), intent(in), dimension(:) :: ipvt
        !! The N-element pivot array as output by lu_factor.
    real(real64), intent(inout), dimension(:) :: b
        !! On input, the N-element right-hand-side array.  On output, the
        !! N-element solution array.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.

    ! Local Variables
    integer(int32) :: n, ld, flag
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Initialization
    n = size(a, 1)
    ! LAPACK requires LDA >= MAX(1, N) and LDB >= MAX(1, N).  Passing N for an
    ! empty system would make DGETRS flag an illegal argument through XERBLA,
    ! which halts the program in the reference LAPACK implementation.
    ld = max(1, n)
    if (present(err)) then
        errmgr => err
    else
        errmgr => deferr
    end if

    ! Input Check
    if (size(a, 2) /= n) then
        call report_square_matrix_error("solve_lu_vec", errmgr, "a", n, &
            size(a, 1), size(a, 2))
        return
    else if (size(ipvt) /= n) then
        call report_array_size_error("solve_lu_vec", errmgr, "ipvt", n, &
            size(ipvt))
        return
    else if (size(b) /= n) then
        call report_array_size_error("solve_lu_vec", errmgr, "b", n, &
            size(b))
        return
    end if

    ! Call DGETRS with a single right-hand side.  FLAG is only nonzero for
    ! illegal arguments, which the checks above rule out.
    call DGETRS("N", n, 1, a, ld, ipvt, b, ld, flag)
end subroutine solve_lu_vec

! ------------------------------------------------------------------------------
subroutine solve_lu_vec_cmplx(a, ipvt, b, err)
    !! Solves the complex system A * x = b for x, given the LU factorization
    !! of A previously computed by lu_factor (LAPACK routine ZGETRS).
    complex(real64), intent(in), dimension(:,:) :: a
        !! The N-by-N LU factored matrix as output by lu_factor.
    integer(int32), intent(in), dimension(:) :: ipvt
        !! The N-element pivot array as output by lu_factor.
    complex(real64), intent(inout), dimension(:) :: b
        !! On input, the N-element right-hand-side array.  On output, the
        !! N-element solution array.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.

    ! Local Variables
    integer(int32) :: n, ld, flag
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Initialization
    n = size(a, 1)
    ! LAPACK requires LDA >= MAX(1, N) and LDB >= MAX(1, N).  Passing N for an
    ! empty system would make ZGETRS flag an illegal argument through XERBLA,
    ! which halts the program in the reference LAPACK implementation.
    ld = max(1, n)
    if (present(err)) then
        errmgr => err
    else
        errmgr => deferr
    end if

    ! Input Check
    if (size(a, 2) /= n) then
        call report_square_matrix_error("solve_lu_vec_cmplx", errmgr, "a", n, &
            size(a, 1), size(a, 2))
        return
    else if (size(ipvt) /= n) then
        call report_array_size_error("solve_lu_vec_cmplx", errmgr, "ipvt", n, &
            size(ipvt))
        return
    else if (size(b) /= n) then
        call report_array_size_error("solve_lu_vec_cmplx", errmgr, "b", n, &
            size(b))
        return
    end if

    ! Call ZGETRS with a single right-hand side.  FLAG is only nonzero for
    ! illegal arguments, which the checks above rule out.
    call ZGETRS("N", n, 1, a, ld, ipvt, b, ld, flag)
end subroutine solve_lu_vec_cmplx

! ------------------------------------------------------------------------------
subroutine csr_lu_solve(lu, ju, b, x, err)
    !! Solves the sparse linear system A * x = b using the LU factors of A
    !! (as produced by lu_factor) via the SPARSKIT routine LUSOL.
    class(msr_matrix), intent(in) :: lu
        !! The LU matrix.
    integer(int32), intent(in), dimension(:) :: ju
        !! The row tracking array.
    real(real64), intent(in), dimension(:) :: b
        !! The right-hand side.
    real(real64), intent(out), dimension(:) :: x
        !! The solution.
    class(errors), intent(inout), optional, target :: err
        !! The error object to be updated.

    ! Local Variables
    integer(int32) :: nrows, ncols
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Use the caller's error object when supplied, otherwise a local default
    errmgr => deferr
    if (present(err)) errmgr => err

    nrows = size(lu, 1)
    ncols = size(lu, 2)

    ! Validate the dimensions in order; the solve only runs when every check
    ! passes, and the first failing check is the one reported.
    if (nrows /= ncols) then
        call report_square_matrix_error("csr_lu_solve", errmgr, "lu", nrows, &
            nrows, ncols)
    else if (size(x) /= nrows) then
        call report_inner_matrix_dimension_error("csr_lu_solve", errmgr, &
            "lu", "x", nrows, size(x))
    else if (size(b) /= nrows) then
        call report_array_size_error("csr_lu_solve", errmgr, "b", nrows, &
            size(b))
    else if (size(ju) /= nrows) then
        call report_array_size_error("csr_lu_solve", errmgr, "ju", nrows, &
            size(ju))
    else
        call lusol(nrows, b, x, lu%values, lu%indices, ju)
    end if
end subroutine csr_lu_solve

! ------------------------------------------------------------------------------
end module